# Start from a clean workspace: remove every object that ls() reports.
rm(list=ls())
# Packages attached for the rest of the script.
library(tidyverse)
library(dplyr)
library(ggplot2)
library(geometry)
library(plot3D)
library(rgl)
library(zoo)
# Make the folder that holds this script the working directory, so that the
# relative 'Data/...' and 'Figures/...' paths used below resolve against it.
# NOTE(review): this goes through rstudioapi, so it presumably only works when
# the script is run from inside RStudio - confirm before running it elsewhere.
setwd(dirname(rstudioapi::getActiveDocumentContext()$path))


# Read the prediction data and keep only the rows in which `race.mt` and every
# model's `race.*` column are all non-missing.
full <- drop_na(
  readRDS('Data/Prediction/img5_weighted_opencv_cov2f_nofw_level.rds'),
  race.mt, race.s, race.gs, race.ff, race.lstm_wk, race.hybrid, race.fbisg
)
#' Vertices of a regular simplex with `n` corners
#'
#' Takes the complete orthogonal factor Q from the QR decomposition of an
#' n x 1 column of ones and drops its first column. The remaining columns are
#' orthonormal and orthogonal to the constant column, so the `n` rows are
#' centred on the origin and are all the same distance (sqrt(2)) apart.
#'
#' @param n Number of vertices; a single number >= 1.
#' @return A numeric matrix with `n` rows and `n - 1` columns, one vertex per
#'   row. The result is always a matrix, including for `n` of 1 or 2.
simplex <- function(n) {
  stopifnot(is.numeric(n), length(n) == 1L, !is.na(n), n >= 1)
  q_full <- qr.Q(qr(matrix(1, nrow = n)), complete = TRUE)
  # drop = FALSE: without it the result collapses to a bare vector when only
  # one column (n = 2) or no column (n = 1) is left.
  q_full[, -1, drop = FALSE]
}
# Corner coordinates of the tetrahedron: one row per vertex, three columns.
tetra <- simplex(4)

# Figure 1
# For each model, split the distinct prediction rows by whether the model's
# predicted race equals `race.mt`, and store per split:
#   df3D_ls[[split]][[model]]      - 3D coordinates of the normalised probabilities
#   df3D_ls$max[[split]][[model]]  - largest raw probability in each row
df3D_ls <- list()
prob_suffix <- c('whi', 'bla', 'his', 'asi')
for (j in c('hybrid', 'fbisg', 'gs', 's', 'lstm_wk', 'ff')) {
  race_col <- paste0('race.', j)
  keep_cols <- c(paste0(j, '.', prob_suffix), race_col)
  for (true_only in c('true', 'false')) {
    # 'true' keeps rows where the prediction equals race.mt, 'false' the rest.
    if (true_only == 'true') {
      row_mask <- full$race.mt == full[[race_col]]
    } else {
      row_mask <- full$race.mt != full[[race_col]]
    }
    # Deduplicate on the four probabilities plus the predicted race, then keep
    # only the probability columns.
    probs <- distinct(full[row_mask, keep_cols])[, 1:4]
    row_peak <- apply(probs, 1, max)
    # Rescale each row so its four entries sum to one.
    shares <- t(apply(probs, 1, function(p) p / sum(p)))

    # Convert barycentric coordinates (4D) to cartesian coordinates (3D)
    df3D_ls[[true_only]][[j]] <- bary2cart(tetra, shares)
    df3D_ls[['max']][[true_only]][[j]] <- row_peak
  }
}

# Labels written at the four tetrahedron vertices, in the row order of `tetra`.
# NOTE(review): presumably these abbreviate the whi/bla/his/asi probability
# columns, in that order - confirm.
coln <- c("W", 'B', "H", "A")

# Draw one row of tetrahedron scatter plots (one panel per model) into a PNG.
#
#   file      path of the PNG to write
#   points    named list of 3-column coordinate matrices, one entry per model
#   max_prob  named list of per-point colour values, parallel to `points`
#   models    model keys, in left-to-right panel order
#   vertices  tetrahedron corner coordinates (4 x 3 matrix)
#   labels    the four vertex labels, in the row order of `vertices`
#
# Returns NULL invisibly; called for the file it writes.
plot_figure1_row <- function(file, points, max_prob,
                             models = c('s', 'gs', 'fbisg', 'lstm_wk', 'ff', 'hybrid'),
                             vertices = tetra, labels = coln) {
  png(file, res = 600, height = 1900, width = 9500)
  # Close the device even if a panel errors, so no open device is left behind.
  on.exit(dev.off(), add = TRUE)
  par(mar = c(0, 0, 0, 0), oma = c(0, 0, 0, 0), xpd = TRUE,
      mfrow = c(1, length(models)), mex = 2)

  # One continuous path over the vertices that passes along all six edges.
  edge_path <- c(1, 2, 3, 4, 1, 3, 1, 2, 4)

  for (model in models) {
    xyz <- points[[model]]
    # The colour key is switched off here; the legend figure draws it separately.
    # NOTE(review): side.lab, line.clab and ellipsoid are forwarded through
    # `...`; confirm that plot3D actually makes use of them.
    scatter3D(xyz[, 1], xyz[, 2], xyz[, 3],
              colvar = max_prob[[model]],
              colkey = FALSE,
              clim = 0:1,
              side.lab = 2, line.clab = 1,
              xlim = c(-0.5, 0.5),
              ylim = c(-0.4, 0.35),
              zlim = c(-0.5, 0.5),
              pch = 16, box = FALSE, theta = 50, phi = 40, alpha = 0.8,
              cex = 1, ellipsoid = TRUE, ticktype = 'detailed')
    lines3D(vertices[edge_path, 1],
            vertices[edge_path, 2],
            vertices[edge_path, 3],
            col = "grey", add = TRUE)
    # Labels are nudged off the vertices so they do not sit on the wireframe.
    text3D(vertices[2:4, 1], vertices[2:4, 2] + 0.05, vertices[2:4, 3],
           theta = 50, phi = 40, labels[-1], add = TRUE, cex = 1.8)
    text3D(vertices[1, 1], vertices[1, 2] - 0.15, vertices[1, 3] + 0.1,
           theta = 50, phi = 40, labels[1], add = TRUE, cex = 1.8)
  }
  invisible(NULL)
}

# Figure 1 (a): rows where the model's prediction equals race.mt.
plot_figure1_row('Figures/Figure 1 (a).png', df3D_ls$true, df3D_ls$max$true)

# Figure 1 (b): rows where the model's prediction differs from race.mt.
plot_figure1_row('Figures/Figure 1 (b).png', df3D_ls$false, df3D_ls$max$false)


# Figure 1 Legend
# Draws the 'false' subset of model 's' once more, this time with the colour
# key switched on, and writes the result to its own PNG.
i <- 's'
legend_xyz <- df3D_ls$false[[i]]
legend_key <- list(side = 1, side.clab = 1, clog = TRUE,
                   cex.axis = 1.5, cex.clab = 2)

png(filename = 'Figures/Figure 1 Legend.png', width = 8000, height = 3000, res = 600)
par(mar = c(4, 8, 0, 8), oma = c(0, 0, 0, 0), xpd = TRUE)
scatter3D(
  x = legend_xyz[, 1],
  y = legend_xyz[, 2],
  z = legend_xyz[, 3],
  colvar = df3D_ls$max$false[[i]],
  colkey = legend_key,
  clim = 0:1,
  clab = 'Predicted Probability',
  side.lab = 2, line.clab = 1,
  xlim = c(-0.5, 0.5),
  ylim = c(-0.4, 0.35),
  zlim = c(-0.5, 0.5),
  pch = 16, box = FALSE, theta = 50, phi = 40,
  alpha = 0.8, cex = 1, ellipsoid = TRUE, ticktype = 'detailed'
)
dev.off()


